import json
import os
from threading import Thread
import time
from openai import OpenAI
import openai
import requests
import tiktoken
import torch
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer, AutoModel

from .config import BaseConfig, ConfigCallAPI, ConfigCallGLM
from .utils import read_api_keys_from_file, save_file, get_unique_id
from langchain_openai import ChatOpenAI
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferMemory
# from langchain.llms import OpenAI


class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation when the latest token is a stop id.

    ``eos_token_id`` may be a single token id, an iterable of token ids, or
    ``None`` (in which case this criterion never fires).
    """
    from torch import LongTensor, FloatTensor

    def __init__(self, eos_token_id) -> None:
        super().__init__()
        self.eos_token_id = eos_token_id
        # Normalise once: iterating a bare int used to raise TypeError on every
        # generation step, and a frozenset gives an O(1) membership test.
        if eos_token_id is None:
            self._stop_ids = frozenset()
        elif isinstance(eos_token_id, int):
            self._stop_ids = frozenset((eos_token_id,))
        else:
            self._stop_ids = frozenset(int(token_id) for token_id in eos_token_id)

    def __call__(self, input_ids: LongTensor, scores: FloatTensor, **kwargs) -> bool:
        """Return True if the last token of the first sequence is a stop id.

        Only ``input_ids[0]`` is inspected, i.e. this assumes single-prompt
        generation (batch size 1).
        """
        return int(input_ids[0][-1]) in self._stop_ids


class CallGLM:
    """Chat with a locally deployed GLM model via Hugging Face ``transformers``.

    The tokenizer and model are loaded from ``config.model_path`` with
    ``trust_remote_code=True``; replies are produced by running
    ``model.generate`` in a background thread and collecting its text stream.
    """

    def __init__(self, config: ConfigCallGLM):
        self.model_path = config.model_path
        self.model_name = config.model_name

        # Load the local model; fail fast when the path is empty or missing.
        if not self.model_path or not os.path.exists(self.model_path):
            raise ValueError(f"[FBI Warning] 本地部署模型路径有误！{self.model_path}不存在")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path, trust_remote_code=True, encode_special_tokens=True)
        self.model = AutoModel.from_pretrained(
            self.model_path, trust_remote_code=True, device_map="auto").eval()
        # The two settings below could also live in the config object.
        self.eos_token_id = config.eos_token_id
        self.stopontokens = StopOnTokens(config.eos_token_id)

    @staticmethod
    def _build_messages(system_prompt, user_input, history=None):
        """Build the chat-template message list.

        ``history`` is a list of prior turns alternating user/assistant,
        starting with a user turn. It is only read, never modified.
        """
        messages = [{"role": "system", "content": system_prompt}]
        for index, message in enumerate(history or []):
            role = "user" if index % 2 == 0 else "assistant"
            messages.append({"role": role, "content": message})
        messages.append({"role": "user", "content": user_input})
        return messages

    def get_response(self, system_prompt, user_input, history=None, **kwargs):
        """Generate a reply to ``user_input`` with the local model.

        Args:
            system_prompt: Content of the leading system message.
            user_input: The new user message to answer.
            history: Optional prior turns (user, assistant, user, ...). Unlike
                earlier versions, the caller's list is no longer mutated.
            **kwargs: Extra arguments forwarded to ``model.generate`` (e.g.
                ``max_new_tokens``, ``top_p``, ``temperature``,
                ``repetition_penalty``). ``input_ids``, ``streamer``,
                ``stopping_criteria``, ``do_sample`` and ``eos_token_id`` are
                always set by this method and override caller values.

        Returns:
            The generated text with surrounding whitespace stripped.
        """
        messages = self._build_messages(system_prompt, user_input, history)
        model_inputs = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
        ).to(self.model.device)
        # timeout=60 bounds how long the streamer waits for each new chunk.
        streamer = TextIteratorStreamer(tokenizer=self.tokenizer, timeout=60,
                                        skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = {
            **kwargs,
            "input_ids": model_inputs,
            "streamer": streamer,
            "stopping_criteria": StoppingCriteriaList([self.stopontokens]),
            "do_sample": True,
            "eos_token_id": self.eos_token_id,
        }
        worker = Thread(target=self.model.generate, kwargs=generate_kwargs)
        worker.start()

        response = "".join(chunk for chunk in streamer if chunk).strip()

        # Make sure generate() has fully returned before releasing cached GPU
        # memory; previously the thread was never joined.
        worker.join()
        del model_inputs, streamer, generate_kwargs, worker
        torch.cuda.empty_cache()

        return response

    # TODO: add chat(self) that keeps a `history` member for the conversation
    # and is implemented on top of get_response.


class CallAPIModel:
    """Call a remote chat-completions model via the OpenAI SDK or a raw HTTP POST.

    API keys read from ``config.api_path`` are handed out round-robin by
    ``get_api_key``. Each request is attempted up to ``config.max_retries``
    times, sleeping ``config.retry_delay`` seconds between failed attempts.
    """

    def __init__(self, config: ConfigCallAPI):
        self.config = config
        self.api_keys = read_api_keys_from_file(config.api_path)
        self.base_url = config.base_url
        self.model_name = config.model_name
        self.max_retries = config.max_retries
        self.retry_delay = config.retry_delay
        self.use_openai_api = config.use_openai_api
        # Flipped to True by initiate_chat(); chat() requires it.
        self.initialized = False

        if not self.api_keys:
            raise ValueError("[FBI Warning] API key is missing!")

        self.current_key_index = 0

    def get_api_key(self):
        """Return the next API key, cycling through ``self.api_keys``."""
        api_key = self.api_keys[self.current_key_index % len(self.api_keys)]
        self.current_key_index += 1
        return api_key

    @staticmethod
    def _build_messages(system_prompt, user_input, history=None):
        """Build a chat-completions message list; ``history`` is read, never mutated.

        ``history`` holds prior turns alternating user/assistant, starting with user.
        """
        messages = [{"role": "system", "content": system_prompt}]
        for index, message in enumerate(history or []):
            role = "user" if index % 2 == 0 else "assistant"
            messages.append({"role": role, "content": message})
        messages.append({"role": "user", "content": user_input})
        return messages

    def get_response(self, system_prompt, user_input, history=None, **kwargs):
        """Return the model's reply to ``user_input``.

        ``**kwargs`` (e.g. ``temperature``, ``max_tokens``) are forwarded to the
        chat-completions request.
        """
        messages = self._build_messages(system_prompt, user_input, history)
        return self._dispatch(messages, **kwargs)

    def _dispatch(self, messages, **kwargs):
        """Send ``messages`` through the configured transport and return the reply text."""
        if self.use_openai_api:
            return self._call_openai_api(messages, **kwargs)
        return self._call_request_api(messages, **kwargs)

    def _call_request_api(self, messages, **kwargs):
        """POST a chat-completions payload to ``self.base_url`` with retries.

        One API key is used for all attempts of this call. An empty reply is
        treated as a failure and retried.

        Raises:
            Exception: when every attempt fails.
        """
        # model/messages always override same-named caller kwargs.
        payload = dict(kwargs, model=self.model_name, messages=messages)
        body = json.dumps(payload).encode('utf-8')
        headers = {
            "Content-Type": "application/json",
            "Authorization": self.get_api_key(),
        }
        # NOTE(review): no `timeout` is passed to requests.post, so a request can
        # block indefinitely; the timeout handlers below only fire if one is added.
        for attempt in range(self.max_retries):
            try:
                response = requests.post(self.base_url, headers=headers, data=body)
                response.raise_for_status()
                result = response.json()['choices'][0]['message']['content'].strip()
            except requests.exceptions.ConnectTimeout as e:
                print(f"API Connect Timeout, attempt {attempt + 1} failed: {e}")
            except requests.exceptions.ReadTimeout as e:
                print(f"API Read Timeout, attempt {attempt + 1} failed: {e}")
            except requests.RequestException as e:
                print(f"Request Error, attempt {attempt + 1} failed: {e}")
            except (KeyError, IndexError, TypeError, ValueError) as e:
                # A 2xx response whose body lacks the expected structure is retried
                # instead of escaping the retry loop.
                print(f"Malformed API response, attempt {attempt + 1} failed: {e}")
            else:
                if result:
                    return result
            if attempt < self.max_retries - 1:
                time.sleep(self.retry_delay)

        raise Exception(f"Max retries {self.max_retries} exceeded")

    def _call_openai_api(self, messages, **kwargs):
        """Call the OpenAI SDK chat-completions endpoint with retries.

        A fresh key (and client) is used for every attempt. Connection,
        rate-limit and other ``openai.APIError`` failures are retried; any other
        exception propagates.

        Raises:
            Exception: when every attempt fails.
        """
        for attempt in range(self.max_retries):
            try:
                client = OpenAI(api_key=self.get_api_key(), base_url=self.base_url)
                response = client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    **kwargs
                )
                return response.choices[0].message.content
            except openai.APIConnectionError as e:
                print(f"Failed to connect to API: {e}")
                print(f"Attempt {attempt + 1} failed: {e}")
            except openai.RateLimitError as e:
                # Consider exponential backoff here.
                print(f"API request exceeded rate limit: {e}")
                print(f"Attempt {attempt + 1} failed: {e}")
            except openai.APIError as e:
                print(f"API returned an API Error: {e}")
                print(f"Attempt {attempt + 1} failed: {e}")
            if attempt < self.max_retries - 1:
                time.sleep(self.retry_delay)

        raise Exception(f"Max retries {self.max_retries} exceeded")

    # Multi-turn chat. This replaces an earlier LangChain ConversationChain attempt
    # that could not run: it passed a raw `openai.OpenAI` client as the LLM and
    # called a non-existent `ConversationChain.append`.
    def initiate_chat(self, system_prompt):
        """Start a conversation seeded with ``system_prompt``.

        Only the first call has an effect; later calls keep the existing conversation.
        """
        if not self.initialized:
            self.chat_messages = [{"role": "system", "content": system_prompt}]
            self.initialized = True

    def chat(self, user_input, temperature=0.7, max_tokens=2000, top_p=1.0, frequency_penalty=0.0, presence_penalty=0.0):
        """Send ``user_input`` in the ongoing conversation, print and return the reply.

        Raises:
            RuntimeError: if ``initiate_chat`` has not been called.
        """
        if not self.initialized:
            raise RuntimeError("initiate_chat() must be called before chat()")
        messages = self.chat_messages + [{"role": "user", "content": user_input}]
        response = self._dispatch(
            messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        # Commit the turn only after a successful call, so a failed request
        # does not leave a dangling user message in the history.
        self.chat_messages.extend([
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": response},
        ])

        print(response)

        return response


# def get_config_class(model_name='gpt-4o', **kwargs):
#     if 'glm' in model_name:
#         return ConfigCallGLM(model_name, **kwargs)
#     elif 'gpt' in model_name:
#         return ConfigCallAPI(model_name, api_path="./api_key_files/openai_api_key.txt", **kwargs)
#     elif 'qwen' in model_name:
#         return ConfigCallAPI(model_name, api_path="./api_key_files/aliyun_api_key.txt",
#                              base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
#     else:
#         raise ValueError("还未集成该模型")


class CallModel:
    """Unified front-end routing requests to a local GLM or a remote API model.

    It enforces per-model input/output token limits before each call and can
    optionally log every ``get_response`` call to disk.
    """

    def __init__(self, config: BaseConfig):
        self.config = config
        model_name = config.model_name
        self.model_name = model_name
        if 'glm' in model_name:
            self.model = CallGLM(config)
            self.max_length_key = 'max_new_tokens'  # transformers generate() kwarg
        elif any(tag in model_name for tag in ('gpt', 'qwen', 'claude-3')):
            self.model = CallAPIModel(config)
            self.max_length_key = 'max_tokens'  # chat-completions kwarg
        else:
            # Message means "this model is not integrated yet".
            raise ValueError("还未集成该模型")
        self.max_input_length, self.max_new_length = self.set_length_limit(model_name)
        self.token_encoder = self.get_encoding(model_name)

    def set_length_limit(self, model_name):
        """Return ``(max_input_tokens, max_new_tokens)`` for ``model_name``.

        More specific names are checked before their prefixes (e.g.
        ``gpt-4o-mini`` before ``gpt-4o`` before ``gpt-4``).
        """
        if model_name == "glm-4-9b-chat":
            return 120000, 8192
        if "gpt-4o-mini" in model_name:
            return 120000, 16384
        if "gpt-4o" in model_name:
            return 120000, 4096
        if "gpt-4-turbo" in model_name:
            return 120000, 4095
        if "gpt-4" in model_name:
            return 8192, 4096
        if "claude-3" in model_name:
            return 120000, 4096
        if model_name == "qwen-max-longcontext":
            return 28000, 2000
        if model_name == "qwen2-7b-instruct":
            return 120000, 6000
        return 4096, 2000

    def get_encoding(self, model_name):
        """Return a tiktoken encoding used to estimate token counts for ``model_name``.

        Raises:
            ValueError: for model families with no known encoding.
        """
        if "glm" in model_name:
            return tiktoken.get_encoding("gpt2")  # approximation: assumes GLM ~ GPT-2 encoding
        if "gpt" in model_name:
            try:
                return tiktoken.encoding_for_model(model_name)
            except KeyError:
                # tiktoken cannot map every model name/alias; fall back to a
                # close approximation instead of failing construction.
                return tiktoken.get_encoding("cl100k_base")
        if "qwen" in model_name or "claude-3" in model_name:
            return tiktoken.get_encoding("cl100k_base")
        raise ValueError("Unknown model name for encoding")

    def get_token_length(self, input_str: str):
        """Estimate the token count of ``input_str``.

        The estimate is not exact, so a safety margin of 20 tokens is added.
        """
        return len(self.token_encoder.encode(input_str)) + 20

    def contains_NG_Words(self, text):
        """Return True if ``text`` looks like a refusal from the model."""
        # Fold typographic apostrophes (U+2019) to ASCII so "can't"/"can’t" and
        # "I'm"/"I’m" variants all match (the old list repeated "but I can't").
        normalized = text.replace("\u2019", "'")
        ng_words = ["but I can't", "but I'm a large language model"]
        prefixes = ("Sorry, ", "I'm sorry, ")
        return any(prefix + word in normalized for word in ng_words for prefix in prefixes)

    def _check_max_new_length(self, kwargs):
        """Raise ValueError if the requested output length exceeds the model limit."""
        max_new_length = kwargs.get(self.max_length_key)
        if max_new_length is not None and max_new_length > self.max_new_length:
            raise ValueError(
                f"Overall max new length exceeds the maximum new length `{self.max_new_length}` "
                f"for the model, reset the key [{self.max_length_key}]")

    def get_response(self, system_prompt, user_input, task='get_response', history=None, save_dir=None, **kwargs):
        """Check token limits, query the model and optionally log the exchange.

        Args:
            system_prompt: System message content.
            user_input: New user message.
            task: Label embedded in the log file name.
            history: Optional prior turns forwarded to the underlying model.
            save_dir: If given, a log file is written under ``save_dir + '/responses/'``.
            **kwargs: Generation parameters forwarded to the underlying model.

        Raises:
            ValueError: if the estimated input (system + user + history) or the
                requested output length exceeds the model's limits.
        """
        sys_tokens = self.get_token_length(system_prompt)
        user_tokens = self.get_token_length(user_input)
        his_tokens = self.get_token_length(str(history)) if history else 0
        # History is sent to the model too, so it counts toward the input budget.
        total_tokens = sys_tokens + user_tokens + his_tokens

        if total_tokens > self.max_input_length:
            raise ValueError(
                f"Input length `{total_tokens}` exceeds the maximum input length "
                f"`{self.max_input_length}` for the model")
        self._check_max_new_length(kwargs)

        response = self.model.get_response(system_prompt, user_input, history, **kwargs)

        if save_dir is not None:  # persist this call's inputs and outputs
            kwargs_str = "".join(f"{key}={value}\n" for key, value in kwargs.items())
            save_file(
                text=f"""
---------- [Kwargs] ---------- \n\n{kwargs_str}\n
---------- [Input_Tokens] ---------- \n\n{sys_tokens} + {user_tokens} = {sys_tokens + user_tokens}\n
---------- [History_Tokens] ---------- \n\n{his_tokens}\n
---------- [Output_Tokens]---------- \n\n{self.get_token_length(response)}\n
---------- [System_prompt] ---------- \n\n{system_prompt}\n
---------- [User_prompt] ---------- \n\n{user_input}\n
---------- [History] ---------- \n\n{str(history)}\n
---------- [Response] ---------- \n\n{response}\n
""",
                path=save_dir+'/responses/', filename=f"{get_unique_id()}_get_response(task={task})", type='txt'
            )
        return response

    # Multi-turn chat, delegated to the underlying model.
    def initiate_chat(self, system_prompt):
        """Start a conversation on the underlying model with ``system_prompt``."""
        self.model.initiate_chat(system_prompt)

    def chat(self, user_input, **kwargs):
        """Send ``user_input`` in the ongoing conversation after checking the output limit.

        Raises:
            ValueError: if the requested output length exceeds the model limit.
        """
        self._check_max_new_length(kwargs)
        return self.model.chat(user_input, **kwargs)
